
Add single-node DDP support, distributed utils, samplers, and gradient accumulation#7

Open
agporto wants to merge 2 commits into code-review from codex/review-multi-gpu-implementation-plan

Conversation

@agporto (Owner) commented Mar 13, 2026

Motivation

  • Enable single-node multi-GPU training using PyTorch Distributed Data Parallel (DDP) and provide utilities to manage distributed state and device placement.
  • Make data loading and metric computation robust in multi-process settings by adding DistributedSampler support and cross-rank gather/aggregation.
  • Improve reproducibility and training flexibility by adding rank-aware seeding, gradient accumulation, and safer device/map_location handling for checkpoint I/O.

Description

  • Added distributed helper utilities in bioencoder/core/utils.py including is_distributed, get_rank, get_world_size, is_main_process, init_distributed, and teardown_distributed, plus a safe _all_gather_cat for gathering variable-length tensors.
  • Extended build_loaders to accept distributed, rank, and world_size and create DistributedSampler instances when enabled, and updated dataset loaders to accept provided samplers.
  • Updated embedding/validation/training functions to accept a device parameter and to perform distributed gathering/aggregation (compute_embeddings, validation_constructive, validation_ce) as well as mixed-precision and gradient accumulation logic in train_epoch_constructive and train_epoch_ce.
  • Made model checkpoint loading/saving device-aware with map_location, and changed build_model/script callers (lr_finder.py, swa.py, train.py) to construct and use a device object instead of calling .cuda() directly.
  • Integrated DDP in train.py with torch.nn.parallel.DistributedDataParallel wrapping, optional SyncBatchNorm conversion, sampler epoch setting per-epoch, rank-aware seeding (set_seed now accepts rank_offset), main-process-only logging/tensorboard/writes, and proper distributed teardown.
  • Added distributed configuration block to bioencoder_configs/train_stage1.yml and train_stage2.yml and documented single-node multi-GPU usage in help/03-training.md.
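A minimal, torch-free sketch of what the helper utilities described above might look like, assuming the standard torchrun environment variables (RANK, WORLD_SIZE); the actual implementations in bioencoder/core/utils.py presumably also consult torch.distributed state once a process group is initialized:

```python
import os

def is_distributed() -> bool:
    """True when launched with more than one process (e.g. via torchrun)."""
    return int(os.environ.get("WORLD_SIZE", "1")) > 1

def get_rank() -> int:
    """Global rank of this process; 0 in single-process runs."""
    return int(os.environ.get("RANK", "0"))

def get_world_size() -> int:
    """Total number of processes; 1 in single-process runs."""
    return int(os.environ.get("WORLD_SIZE", "1"))

def is_main_process() -> bool:
    """Only rank 0 should log, write checkpoints, or update TensorBoard."""
    return get_rank() == 0
```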

Testing

  • Ran automated smoke tests that import the package and execute bioencoder.scripts.train in --dry-run mode on a single GPU; the dry-run completed successfully.
  • Executed lr_finder and swa scripts in a single-GPU environment against a small dataset as an automated smoke check; both ran and returned expected outputs.
  • No multi-GPU torchrun/DDP CI job was run here. The non-distributed code paths were exercised by the smoke tests with distributed.enabled set to False, and device/map_location handling was verified by unit-like checks; those checks passed.
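The rank-aware seeding mentioned in the description can be illustrated with a small sketch; the signature mirrors the PR text (set_seed now accepts rank_offset), though the real helper presumably also seeds numpy and torch:

```python
import random

def set_seed(seed: int, rank_offset: int = 0) -> None:
    # Each rank derives its own seed from the base seed, so random
    # augmentations differ across processes while each run stays
    # reproducible for a given (seed, rank) pair.
    random.seed(seed + rank_offset)
```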

Codex Task

@agporto agporto marked this pull request as ready for review March 14, 2026 05:20

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: de3db03cff


Comment on lines +58 to +59
distributed_enabled = kwargs.get("distributed", dist_config.get("enabled", False))
distributed_backend = kwargs.get("backend", dist_config.get("backend", "nccl"))


P2: Honor YAML distributed.enabled when CLI flag is absent

train() reads distributed_enabled from kwargs first, but cli() always passes distributed=args.distributed (default False), so distributed.enabled: true in YAML is silently ignored unless the user also adds --distributed. This contradicts the documented config-based enablement path and can cause users to run single-process training unexpectedly when launching from config only.
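One possible fix, sketched here with a hypothetical helper (resolve_distributed_enabled is not in the PR), is to make the argparse default None rather than False, so an absent CLI flag defers to the YAML value:

```python
def resolve_distributed_enabled(cli_value, dist_config: dict) -> bool:
    # cli_value should default to None (not False) in argparse, so that
    # "flag not given" is distinguishable from an explicit opt-out.
    if cli_value is None:
        return bool(dist_config.get("enabled", False))
    return bool(cli_value)
```

With this scheme, distributed.enabled: true in the YAML takes effect unless the user explicitly overrides it on the command line.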


Comment on lines +283 to +287
valid_dataset,
num_replicas=world_size,
rank=rank,
shuffle=False,
drop_last=False,


P1: Prevent duplicated validation samples in DDP metrics

Using DistributedSampler(..., drop_last=False) for validation pads each rank with repeated indices when len(valid_dataset) is not divisible by world_size; later validation_ce/validation_constructive all-gather predictions and compute metrics on the padded set, so accuracy/F1 and checkpoint selection become biased for many dataset sizes. The validation path should avoid padded repeats (or trim to the original dataset length before metric computation).
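A sketch of one possible remedy (names hypothetical): gather the sample indices alongside the predictions, then keep only the first prediction seen for each index so padded repeats never reach the metric computation:

```python
def trim_padded_predictions(indices, preds, dataset_len):
    # DistributedSampler(drop_last=False) pads ranks with repeated indices
    # when dataset_len % world_size != 0; after the all-gather, drop the
    # repeats and restore original dataset order before computing
    # accuracy/F1 and doing checkpoint selection.
    out = [None] * dataset_len
    for i, p in zip(indices, preds):
        if out[i] is None:
            out[i] = p
    return out
```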

